package com.test.chatrobot3.test;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Converts a list of tokens into a binary presence vector over a vocabulary.
 *
 * <p>The vector is allocated with {@code Nd4j.zeros(vocab.size())}; every token for which the
 * vocabulary reports an index other than {@code -1} has the entry at that index set to 1.
 */
public class TextVectorizer {
    /** Vocabulary used to size the vector and to resolve token indices; never {@code null}. */
    private final VocabularyBuilder vocab;

    /**
     * Creates a vectorizer backed by the given vocabulary.
     *
     * @param vocab vocabulary that supplies the vector size and token indices
     * @throws NullPointerException if {@code vocab} is {@code null}
     */
    public TextVectorizer(VocabularyBuilder vocab) {
        // Fail fast here rather than with an unexplained NPE on the first vectorize() call.
        this.vocab = Objects.requireNonNull(vocab, "vocab");
    }

    /**
     * Builds the presence vector for the given tokens.
     *
     * <p>Tokens for which {@code vocab.getIndex} returns {@code -1} are skipped. A token that
     * occurs several times still yields 1, because the entry is set rather than incremented.
     *
     * @param tokens tokens to encode
     * @return a new array with 1 at the index of every recognised token and 0 elsewhere
     * @throws NullPointerException if {@code tokens} is {@code null}
     */
    public INDArray vectorize(List<String> tokens) {
        Objects.requireNonNull(tokens, "tokens");
        INDArray vector = Nd4j.zeros(vocab.size());
        for (String token : tokens) {
            int index = vocab.getIndex(token);
            if (index != -1) {
                vector.putScalar(index, 1);
            }
        }
        return vector;
    }

    /**
     * Small demo: builds a three-word vocabulary and prints the vector for two of the words.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        VocabularyBuilder vocab = new VocabularyBuilder();
        vocab.addWord("hello");
        vocab.addWord("world");
        vocab.addWord("test");

        TextVectorizer vectorizer = new TextVectorizer(vocab);
        List<String> tokens = Arrays.asList("hello", "world");
        INDArray vector = vectorizer.vectorize(tokens);
        // Expected: the entries for "hello" and "world" are 1 and the entry for "test" is 0.
        // NOTE(review): presumably indices follow insertion order (1, 1, 0) - confirm against
        // VocabularyBuilder; the exact printed format depends on the ND4J version.
        System.out.println(vector);
    }
}